{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe6fd0ab0e1ad844",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Transfer Learning with Sentence Transformers and Scikit-Learn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dcde44d942793ff",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Introduction\n",
    "\n",
    "In this notebook, we will explore the process of transfer learning using SuperDuperDB. We will demonstrate how to connect to a MongoDB datastore, load a dataset, create a SuperDuperDB model based on Sentence Transformers, train a downstream model using Scikit-Learn, and apply the trained model to the database. Transfer learning is a powerful technique that can be used in various applications, such as vector search and downstream learning tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1809feca8a8dca5a",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Prerequisites\n",
    "\n",
    "Before diving into the implementation, ensure that you have the necessary libraries installed by running the following commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f3219ad932a327",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install superduperdb\n",
    "!pip install ipython numpy datasets sentence-transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bc151f6",
   "metadata": {},
   "source": [
    "## Connect to datastore "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5379007991707d17",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "First, we need to establish a connection to a MongoDB datastore via SuperDuperDB. You can configure the `MongoDB_URI` based on your specific setup. \n",
    "Here are some examples of MongoDB URIs:\n",
    "\n",
    "* For testing (default connection): `mongomock://test`\n",
    "* Local MongoDB instance: `mongodb://localhost:27017`\n",
    "* MongoDB with authentication: `mongodb://superduper:superduper@mongodb:27017/documents`\n",
    "* MongoDB Atlas: `mongodb+srv://<username>:<password>@<atlas_cluster>/<database>`"
   ]
  },
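  {
   "cell_type": "markdown",
   "id": "sketch-uri-parts-md",
   "metadata": {},
   "source": [
    "As a rough sketch (not part of the SuperDuperDB API), the standard library's `urllib.parse.urlsplit` can show which scheme, host, and database a URI refers to before you connect. The helper name `describe_uri` is made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-uri-parts-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from urllib.parse import urlsplit\n",
    "\n",
    "def describe_uri(uri):\n",
    "    # Hypothetical helper: split a connection URI into (scheme, host, database)\n",
    "    parts = urlsplit(uri)\n",
    "    # The path component carries the database name, when one is given\n",
    "    return parts.scheme, parts.hostname, parts.path.lstrip('/') or None\n",
    "\n",
    "for uri in [\n",
    "    'mongomock://test',\n",
    "    'mongodb://localhost:27017',\n",
    "    'mongodb://superduper:superduper@mongodb:27017/documents',\n",
    "]:\n",
    "    print(describe_uri(uri))"
   ]
  },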
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44f8ef76",
   "metadata": {},
   "outputs": [],
   "source": [
    "from superduperdb import superduper\n",
    "from superduperdb.backends.mongodb import Collection\n",
    "import os\n",
    "\n",
    "mongodb_uri = os.getenv(\"MONGODB_URI\", \"mongomock://test\")\n",
    "\n",
    "# SuperDuperDB, now handles your MongoDB database\n",
    "# It just super dupers your database\n",
    "db = superduper(mongodb_uri, artifact_store='filesystem://./data/')\n",
    "\n",
    "# Reference a collection called transfer\n",
    "collection = Collection('transfer')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97fede97",
   "metadata": {},
   "source": [
    "## Load Dataset\n",
    "\n",
    "Transfer learning can be applied to any data that can be processed with SuperDuperDB models.\n",
    "For our example, we will use a labeled textual dataset with sentiment analysis.  We'll load a subset of the IMDb dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb65106",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "from datasets import load_dataset\n",
    "from superduperdb import Document as D\n",
    "\n",
    "# Load IMDb dataset\n",
    "data = load_dataset(\"imdb\")\n",
    "\n",
    "# Set the number of data points for training (adjust as needed)\n",
    "N_DATAPOINTS = 500\n",
    "\n",
    "# Prepare training data\n",
    "train_data = [\n",
    "    D({'_fold': 'train', **data['train'][int(i)]})\n",
    "    for i in numpy.random.permutation(len(data['train']))\n",
    "][:N_DATAPOINTS]\n",
    "\n",
    "# Prepare validation data\n",
    "valid_data = [\n",
    "    D({'_fold': 'valid', **data['test'][int(i)]})\n",
    "    for i in numpy.random.permutation(len(data['test']))\n",
    "][:N_DATAPOINTS // 10]\n",
    "\n",
    "# Insert training data into the 'collection' SuperDuperDB collection\n",
    "db.execute(collection.insert_many(train_data))"
   ]
  },
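  {
   "cell_type": "markdown",
   "id": "sketch-seeded-sample-md",
   "metadata": {},
   "source": [
    "The cell above draws a different random subset on every run. If you need a reproducible subset, one option is a seeded NumPy generator. This is a minimal sketch: the helper name `sample_indices` and the seed value are arbitrary choices for illustration, and 25000 is the size of the IMDb train split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-seeded-sample-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "\n",
    "def sample_indices(n_total, n_samples, seed=42):\n",
    "    # Hypothetical helper: draw n_samples distinct indices from range(n_total)\n",
    "    rng = numpy.random.default_rng(seed)\n",
    "    # Sampling without replacement guarantees that no index repeats\n",
    "    return rng.choice(n_total, size=n_samples, replace=False)\n",
    "\n",
    "idx = sample_indices(25000, 500)\n",
    "# The same seed yields the same indices on every run\n",
    "print(numpy.array_equal(idx, sample_indices(25000, 500)))"
   ]
  },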
  {
   "cell_type": "markdown",
   "id": "00a92214",
   "metadata": {},
   "source": [
    "## Run Model\n",
    "\n",
    "We'll create a SuperDuperDB model based on the `sentence_transformers` library. This demonstrates that you don't necessarily need a native SuperDuperDB integration with a model library to leverage its power. We configure the `Model wrapper` to work with the `SentenceTransformer class`. After configuration, we can link the model to a collection and daemonize the model with the `listen=True` keyword."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fef91c74",
   "metadata": {},
   "outputs": [],
   "source": [
    "from superduperdb import Model\n",
    "import sentence_transformers\n",
    "from superduperdb.ext.numpy import array\n",
    "\n",
    "# Create a SuperDuperDB Model using Sentence Transformers\n",
    "m = Model(\n",
    "    identifier='all-MiniLM-L6-v2',\n",
    "    object=sentence_transformers.SentenceTransformer('all-MiniLM-L6-v2'),\n",
    "    encoder=array('float32', shape=(384,)),\n",
    "    predict_method='encode',\n",
    "    batch_predict=True,\n",
    ")\n",
    "\n",
    "# Make predictions on 'text' data from the 'collection' SuperDuperDB collection\n",
    "m.predict(\n",
    "    X='text',\n",
    "    db=db,\n",
    "    select=collection.find(),\n",
    "    listen=True,\n",
    "    show_progress_bar=True,\n",
    ")"
   ]
  },
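  {
   "cell_type": "markdown",
   "id": "sketch-embedding-shape-md",
   "metadata": {},
   "source": [
    "To see why the encoder is declared with `shape=(384,)`, you can call the Sentence Transformers model directly, outside SuperDuperDB. This is a small sketch; the two sentences are invented examples, and the cell loads the model a second time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-embedding-shape-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "encoder = SentenceTransformer('all-MiniLM-L6-v2')\n",
    "# encode() accepts a list of strings and returns a NumPy array by default\n",
    "embeddings = encoder.encode(['A wonderful film.', 'A dull, overlong mess.'])\n",
    "# One row per sentence, one column per embedding dimension\n",
    "print(embeddings.shape, embeddings.dtype)"
   ]
  },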
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aed95196-4454-4d88-beb8-dc6dddfcbe33",
   "metadata": {},
   "outputs": [],
   "source": [
    "db.execute(collection.find_one())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68fefc17",
   "metadata": {},
   "source": [
    "## Train Downstream Model\n",
    "Now that we've created and added the model that computes features for the `\"text\"`, we can train a downstream model using Scikit-Learn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c2faeeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import necessary modules and classes\n",
    "from sklearn.svm import SVC\n",
    "from superduperdb import superduper\n",
    "\n",
    "# Create a SuperDuperDB model with an SVC classifier\n",
    "model = superduper(\n",
    "    SVC(gamma='scale', class_weight='balanced', C=100, verbose=True),\n",
    "    postprocess=lambda x: int(x)\n",
    ")\n",
    "\n",
    "# Train the model on 'text' data with corresponding labels\n",
    "model.fit(\n",
    "    X='_outputs.text.all-MiniLM-L6-v2.0',\n",
    "    y='label',\n",
    "    db=db,\n",
    "    select=collection.find(),\n",
    ")\n"
   ]
  },
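  {
   "cell_type": "markdown",
   "id": "sketch-plain-svc-md",
   "metadata": {},
   "source": [
    "For intuition, the downstream step amounts to fitting a classifier on fixed-size vectors. The sketch below does this with plain Scikit-Learn on synthetic data. The random 384-dimensional vectors and the toy labels are stand-ins, not real IMDb embeddings, and the name `clf` is chosen so that the `model` variable above is left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-plain-svc-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "rng = numpy.random.default_rng(0)\n",
    "# Synthetic stand-ins for 384-dimensional sentence embeddings\n",
    "X = rng.normal(size=(200, 384)).astype('float32')\n",
    "# Toy binary labels derived from the sign of the first coordinate\n",
    "y = (X[:, 0] > 0).astype(int)\n",
    "\n",
    "clf = SVC(gamma='scale', class_weight='balanced', C=100)\n",
    "clf.fit(X[:150], y[:150])\n",
    "\n",
    "# Predicted labels for a few held-out rows, then held-out accuracy\n",
    "print(clf.predict(X[150:155]))\n",
    "print(clf.score(X[150:], y[150:]))"
   ]
  },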
  {
   "cell_type": "markdown",
   "id": "d1e1f164",
   "metadata": {},
   "source": [
    "## Run Downstream Model\n",
    "\n",
    "With the model trained, we can now apply it to the database. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee16436",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make predictions on 'text' data with the trained SuperDuperDB model\n",
    "model.predict(\n",
    "    X='_outputs.text.all-MiniLM-L6-v2.0',\n",
    "    db=db,\n",
    "    select=collection.find(),\n",
    "    listen=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67b156c1",
   "metadata": {},
   "source": [
    "## Verification\n",
    "\n",
    "To verify that the process has worked, we can sample a few records to inspect the sanity of the predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76958a1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query a random document from the 'collection' SuperDuperDB collection\n",
    "r = next(db.execute(collection.aggregate([{'$sample': {'size': 1}}])))\n",
    "\n",
    "# Print a portion of the 'text' field from the random document\n",
    "print(r['text'][:100])\n",
    "\n",
    "# Print the prediction made by the SVC model stored in '_outputs'\n",
    "print(r['_outputs']['text']['svc']['0'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
